# -*-Python-*-

# Universal Transformer (UT) configuration: the encoder and decoder layer stacks
# are built with universal_transformer.UTLayerStack instead of the vanilla
# transformer layer stack. See https://arxiv.org/abs/1807.03819.

#include 'models/bi_v1_tiny.gin'
import mesh_tensorflow.transformer.transformer_layers
import mesh_tensorflow.transformer.universal_transformer
# Turn off Unitransformer.shared_embedding_and_softmax_weights.
Unitransformer.shared_embedding_and_softmax_weights = False
# `dropout_rate` and `num_layers` are Gin macros: they only take effect where
# %dropout_rate / %num_layers are referenced (the references are not in this
# file -- presumably in the base model config; confirm).
dropout_rate = 0.0
# The UT layer is repeated num_layers times, so num_layers should be set to 1
# for UT.
num_layers = 1
# Encoder-scoped make_layer_stack bindings.
# NOTE(review): the effect of block_scope is defined by make_layer_stack, which
# is not visible in this file -- confirm why it must be False for UT.
encoder/make_layer_stack.block_scope = False
# Layers of the encoder stack, as (name, layer class) pairs in list order:
# self-attention followed by DenseReluDense.
encoder/transformer.make_layer_stack.layers = [
   ('self_attention',
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention),
   ('dense_relu_dense',
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]
# Decoder-scoped make_layer_stack bindings.
# NOTE(review): the effect of block_scope is defined by make_layer_stack, which
# is not visible in this file -- confirm why it must be False for UT.
decoder/make_layer_stack.block_scope = False
# Layers of the decoder stack, as (name, layer class) pairs in list order:
# self-attention, encoder-decoder attention, then DenseReluDense.
decoder/transformer.make_layer_stack.layers = [
    ('self_attention',
     @mesh_tensorflow.transformer.transformer_layers.SelfAttention),
    ('enc_dec_attention',
     @mesh_tensorflow.transformer.transformer_layers.EncDecAttention),
    ('dense_relu_dense',
     @mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]
# Use universal_transformer.UTLayerStack as the layer-stack class for both the
# encoder and the decoder.
encoder/transformer.make_layer_stack.layer_stack_cls = @universal_transformer.UTLayerStack
decoder/transformer.make_layer_stack.layer_stack_cls = @universal_transformer.UTLayerStack
# UTLayerStack hyperparameters. These bindings carry no encoder/ or decoder/
# scope, so the same values apply to both stacks.
#
# Halting and recurrence settings.
# NOTE(review): the act_* parameters presumably configure adaptive computation
# time; their exact meaning is defined in UTLayerStack -- confirm there.
UTLayerStack.act_type = "basic"
UTLayerStack.recurrence_type = "basic"
UTLayerStack.act_max_steps = 24
UTLayerStack.act_epsilon = 0.05 # original=0.01
UTLayerStack.num_rec_steps = 24
UTLayerStack.num_inrecurrence_layers = 1
# Timing signals: both the position and the step signal are enabled, the step
# signal type is "learned", and signals are combined with "add" rather than
# concatenation.
UTLayerStack.position_start_index = None
UTLayerStack.add_or_concat_timing_signal = "add"
UTLayerStack.step_timing_signal_type = "learned"
UTLayerStack.add_position_timing_signal = True
UTLayerStack.add_step_timing_signal = True
# Enable mixing with transformer layers both before and after the UT part of
# the stack (semantics defined in UTLayerStack -- confirm).
UTLayerStack.mix_with_transformer_before_ut = True
UTLayerStack.mix_with_transformer_after_ut = True
# Gating settings.
# NOTE(review): with recurrence_type = "basic" some of these may not be read
# at all -- verify against UTLayerStack before tuning them.
UTLayerStack.gates_inputs = "i"
UTLayerStack.gate_ffn_layer = "dense"
UTLayerStack.couple_carry_transform_gates = True
UTLayerStack.use_gated_transformer = True
UTLayerStack.gating_type = "gru"
